Skip to content

[Pyt][Common] Enabling/Guarding sm120 support (non - attention)#2833

Open
KshitijLakhani wants to merge 33 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/maint/sm120-pyt-cpp-support
Open

[Pyt][Common] Enabling/Guarding sm120 support (non - attention)#2833
KshitijLakhani wants to merge 33 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/maint/sm120-pyt-cpp-support

Conversation

@KshitijLakhani
Copy link
Copy Markdown
Collaborator

@KshitijLakhani KshitijLakhani commented Apr 3, 2026

Description

This PR is a follow up to : #2693.

PR #2693 aimed to enable/guard PyT attention for sm120
This PR aims to enable/guard non-attention for sm120 (and a small attn related regression fix)

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Runtime/backend guards for SM120 correctness

  • Disabled gated TMA backward kernels to avoid kernel launch failures tied to shared-memory constraints.
  • Forced unfused NVFP4 RHT path to avoid fused RHT shared-memory/resource overreach.
  • Disabled NVFP4 stochastic rounding on backend paths in csrc/quantizer.cpp due to unsupported .rs PTX.
  • Added grouped NVFP4 fallback in cast extension (csrc/extensions/cast.cpp) to use safer per-split processing.
  • Added grouped GEMM runtime guard in gemm/cublaslt_grouped_gemm.cu because cuBLASLt grouped GEMM heuristic returns unsupported (for affected BF16/FP8 cases).

General Bug fix (not SM120 specific)
I stumbled upon this bug specifically when I was testing on SM120, but it is an arch agnostic fix.

  • Fixed MXFP8 CAST_DBIAS shared-memory handoff race
    • Ensured async shared->global source consumption is complete and all warps reach safe reuse point

NVFP4 grouped quantization layout consistency for SM120

  • Aligned grouped NVFP4 metadata with actual SM120 fallback output layout in csrc/quantizer.cpp:
    • default: metadata follows optimize_for_gemm,
    • SM120 grouped fallback (first_dims present): force unswizzled metadata.
  • Propagated grouped layout metadata into split tensor views in grouped_tensor_storage.py so split tensors inherit true grouped layout state.
  • Updated grouped NVFP4 tests in test_nvfp4_group_quantize_graph_safe.py to compare against metadata-selected reference layout and use scoped SM120 tolerance behavior.

Test changes (SM120 specific)

  • NVFP4 SR tests: in test_nvfp4_sr_quantize.py, changed SM120 expectation from SR < RN to numerical equivalence (assert_close) because SR is disabled on SM120 backend.
  • FP8 CS numerics: in run_layer_with_overlap.py, added SM120-only looser tolerance for fp8_current_scaling (rtol=0.4, atol=0.25) in deterministic fallback backend scenarios (I borrowed these tolerances from the corresponding distributed test file run_numerics.py)
  • BF16 multi-layer overlap numerics: added narrowly-scoped SM120 tolerance relaxation (rtol=0.05, atol=0.01) for TransformerLayer, multi-layer, overlap_rs_dgrad when deterministic mode routes away from fused attention.
  • THD-vs-dense tolerance and grouped GEMM skips in test_numerics.py, C++ grouped GEMM operator tests, and PyTorch grouped GEMM numerics to match explicit SM120 unsupported/runtime-guarded paths.
  • Skipped SM120 NVFP4 paged-stashing grouped-quantize case due to observed IMA in current kernel assumptions for paged layouts.

SM120 coverage/test harness updates

  • Made custom-recipe grouped-linear shapes 16-aligned on SM120 because the SM120 FP8 GEMM path enforces leading-dimension alignment (lda % 16 == 0) in backward.
  • Narrow distributed debug/tolerance helper updates in run_distributed.py and related tests for observed SM120 outlier behavior.
  • Relaxed one NVFP4 bias-grad check (single element outlier in ffn1.bias.grad exceeded prior absolute tolerances) for SM120 only

Fused attention SM120 regression fix
Reinstated lost SM120 conditionals in fused_attn_f16_arbitrary_seqlen.cu (This was likely lost during conflict resolution when merging of PR 2677):

  • restored SM120-specific behavior for stats stride selection (use_ragged_stats path),
  • restored SM120-aware output_S shape handling for THD + cuDNN >= 9.6.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Apr 3, 2026
@KshitijLakhani KshitijLakhani changed the title [Pyt][Common Enabling/Guarding sm120 support (non - attention) [Pyt][Common] Enabling/Guarding sm120 support (non - attention) Apr 3, 2026
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch 2 times, most recently from 59ab765 to 5cbb074 Compare April 10, 2026 07:40
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch 4 times, most recently from b01b227 to ccf0da4 Compare April 22, 2026 07:19
@KshitijLakhani KshitijLakhani marked this pull request as ready for review April 22, 2026 22:32
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 22, 2026

Greptile Summary

This PR guards and enables SM120 (Blackwell B200) support across non-attention PyTorch TE paths, following up on the attention-specific PR #2693. It also carries an arch-agnostic MXFP8 CAST_DBIAS shared-memory race fix and restores SM120 conditionals in fused attention that were lost during a prior merge conflict.

  • Runtime guards: disables gated TMA backward kernels, stochastic-rounding PTX paths, and fused grouped-NVFP4/RHT kernels on SM120; blocks cuBLASLt grouped GEMM on SM120 and SM121; disables FlashAttention 4 on SM120.
  • NVFP4 grouped-quantize fallback: on SM120, replaces the fused grouped kernel with a per-split unfused path and post-cast in-place swizzle to honor optimize_for_gemm; propagates the actual grouped layout into split tensor views.
  • Test / tolerance updates: adds SM120-scoped tolerance relaxations, grouped-GEMM skips, and SR-equivalence assertions across multiple test files.

Confidence Score: 4/5

Safe to merge for SM120 hardware, but has a concrete gap for SM121 devices where several backend guards are incomplete.

The C++ runtime (cublaslt_grouped_gemm.cu) and C++ test skip (test_grouped_gemm.cu) both explicitly block SM121 alongside SM120, but is_sm120_device() and every Python-level capability check (== (12, 0)) only detect SM120. On SM121 hardware, the grouped NVFP4 fallback and stochastic-rounding disable won't activate, and the Python grouped-GEMM test skips won't fire — causing the C++ NVTE_CHECK to throw rather than producing a clean pytest skip. The rest of the changes (MXFP8 race fix, TMA kernel disable, FlashAttention 4 guard, fused-attention regression restore) look correct and well-scoped.

transformer_engine/pytorch/csrc/util.h (is_sm120_device), transformer_engine/pytorch/csrc/extensions/cast.cpp (SM120 fallback condition), tests/pytorch/test_fusible_ops.py (SM120 capability checks)

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/util.h Adds is_sm120_device() helper checking sm_arch == 120 exactly; misses SM121, which is already blocked at C++ runtime in grouped GEMM and C++ test skips.
transformer_engine/pytorch/csrc/extensions/cast.cpp Large addition of SM120 grouped NVFP4 fallback (per-split path), post-cast in-place swizzle, and SR disable; fallback only activates when first_dims is present — unclear if graph-safe (no first_dims) path is safe on SM120.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Adds NVTE_CHECK blocking SM120 and SM121 for grouped GEMM; straightforward runtime guard consistent with C++ test skip.
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh Adds cp_async_bulk_wait_group_read + __syncthreads() before parity flip to fix MXFP8 CAST_DBIAS shared-memory handoff race; arch-agnostic fix that looks correct.
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu Restores two SM120-specific conditionals (stats stride and output_S shape) lost during prior merge conflict resolution.
tests/pytorch/test_fusible_ops.py Skips grouped cuBLASLt GEMM tests on SM120 via (12, 0) capability check; gap: won't skip on SM121 which is also blocked at C++ runtime, causing a throw rather than clean skip.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["group_quantize() called"] --> B{SM120 device?}
    B -- "No" --> C["group_quantize_nvfp4_impl\n(fused grouped kernel)"]
    B -- "Yes + first_dims present" --> D["SM120 fallback:\nsplit_quantize_nvfp4_impl\n(per-split unfused)"]
    B -- "Yes, no first_dims" --> C
    D --> E{optimize_for_gemm?}
    E -- "No" --> F["Compact-layout scales emitted"]
    E -- "Yes" --> G["Compact-layout cast\n→ post-cast in-place swizzle"]
    G --> H["Split tensors inherit grouped _with_gemm_swizzled_scales"]
    F --> H
Loading

Reviews (15): Last reviewed commit: "Relax sanity atol for SM120 + NVFP4 quan..." | Re-trigger Greptile

Comment thread transformer_engine/common/common.cu Outdated
Comment on lines +290 to +295
// KL: test function for CC 120
bool is_supported_by_CC_120() {
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());

return deviceComputeCapability == 120;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Debug/WIP comment and misleading function name

The // KL: test function for CC 120 comment should be removed before merging — it reads as a personal debug note rather than production documentation.

More importantly, the name is_supported_by_CC_120() is semantically inconsistent with is_supported_by_CC_100(). is_supported_by_CC_100 returns >= 100 (meaning "supported by CC 100 or newer"), so by analogy is_supported_by_CC_120 would imply >= 120. However the implementation returns == 120 (exclusively SM120). Every call site uses this to disable a feature on SM120, not to enable something on SM120+. A name like is_exactly_CC_120() or is_CC_120_arch() would prevent future readers from misinterpreting the range semantics.

Comment thread transformer_engine/common/cast/dispatch/gated.cuh Outdated
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from 440ba8b to 4aed9e9 Compare April 23, 2026 18:45
Comment thread tests/pytorch/test_grouped_quantize_fp8_current_scaling.py Outdated
Comment thread transformer_engine/common/cast/grouped_fp8_current_scaling.cu Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/grouped_fp8_bindings.cpp Outdated
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from 0b00fef to a95ba1c Compare April 24, 2026 18:41
@KshitijLakhani KshitijLakhani added the enhancement New feature or request label Apr 28, 2026
Comment on lines +69 to +72
/*! \brief Check whether the current CUDA device is SM120. */
inline bool is_sm120_device() {
return transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 120;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be checking for any SM 12.X arch?

Suggested change
/*! \brief Check whether the current CUDA device is SM120. */
inline bool is_sm120_device() {
return transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 120;
}
/*! \brief Check whether the current CUDA device is SM12X. */
inline bool is_sm12x_device() {
return transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) / 10 == 12;
}

This pattern shows up throughout this PR.

Comment on lines +978 to +979
# Use the actual grouped-output layout. This can differ from the requested
# quantizer flag if the backend produces a different layout (e.g. sm120)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment (and the other one below) seems wrong to me. The contract is that I give tex.group_quantize a quantizer, and it gives me a matching grouped tensor. tex.group_quantize might internally have a fused or unfused implementation based on the SM arch, but externally I don't care since the results are the same.

Comment on lines +1939 to +1940
const bool with_gemm_swizzled_scales =
this->optimize_for_gemm && !enable_sm120_grouped_nvfp4_fallback;
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of quantizers is to hide details of the recipes and supported kernel fusions. The contract is if the quantizer has optimize_for_gemm=True, then the quantized tensor has swizzled scales. The caller does not need to care or do any extra work depending on their system (or at least, they should get an error message). We should remove this logic and instead perform an unfused cast + swizzle in the quantize functions.

Comment on lines +2239 to +2243
// Disable stochastic-rounding FP4 cast path for SM120, which relies on arch-specific PTX
// instructions.
const bool sm120_device = is_sm120_device();
const bool use_stochastic_rounding = this->stochastic_rounding && !sm120_device;
quant_config.set_stochastic_rounding(use_stochastic_rounding);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should error out rather than silently ignoring user instructions:

Suggested change
// Disable stochastic-rounding FP4 cast path for SM120, which relies on arch-specific PTX
// instructions.
const bool sm120_device = is_sm120_device();
const bool use_stochastic_rounding = this->stochastic_rounding && !sm120_device;
quant_config.set_stochastic_rounding(use_stochastic_rounding);
const bool use_stochastic_rounding = this->stochastic_rounding;
if (use_stochastic_rounding && is_sm120_device()) {
NVTE_ERROR("NVFP4 does not support stochastic rounding on SM 12X");
}
quant_config.set_stochastic_rounding(use_stochastic_rounding);

Comment on lines +95 to +97
// The returned vector is used by NVFP4 grouped-quantize to split the input
// tensor into per-group sub-tensors.
// Currently, only used for SM120 NVFP4 grouped-quantize fallback.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I guess it's not that important since this is an internal helper function, but comments like this become wrong very quickly.

Comment thread tests/pytorch/test_fusible_ops.py Outdated
Comment on lines +3317 to +3322
# SM120 NVFP4 can show a tiny outlier for single element in this bias grad entry while all
# other checks stay within the existing loose sanity tolerances.
b1_tols = tols
if quantization == "nvfp4" and torch.cuda.get_device_capability() == (12, 0):
b1_tols = {"rtol": tols["rtol"], "atol": 0.55}
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **b1_tols)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bug seems like something we should fix, not hackily work around. Do we have any more info?

Suggested change
# SM120 NVFP4 can show a tiny outlier for single element in this bias grad entry while all
# other checks stay within the existing loose sanity tolerances.
b1_tols = tols
if quantization == "nvfp4" and torch.cuda.get_device_capability() == (12, 0):
b1_tols = {"rtol": tols["rtol"], "atol": 0.55}
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **b1_tols)
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)

Comment thread tests/pytorch/test_custom_recipe.py Outdated
Comment on lines +122 to +123
# Use 16-aligned splits on SM120 to satisfy FP8 GEMM leading-dimension requirements in backward.
is_sm120 = torch.cuda.get_device_capability() == (12, 0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like how SM 120 logic is spilling out into unrelated tests. I'd prefer just increasing the batch size so it supports all cases. Similar for the other change in this file.

Comment on lines +31 to +32
# SM120 currently disables NVFP4 stochastic rounding in backend paths,
# so SR and RN should be numerically equivalent.
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd expect a function called _assert_sr_vs_rn_behavior to assert correct behavior in stochastic rounding vs round-to-nearest. A more accurate name would be something cumbersome like _assert_sr_setting_vs_true_rn_behavior, which is a sign of a design mistake (silently suppressing stochastic rounding rather than erroring out). One reason to put effort into choosing accurate names is that good names impose a tax on bad design.

Comment on lines +563 to +569
if (
opts.quantization == "fp8_current_scaling"
and is_sm120
and is_deterministic_mode
):
# SM120 deterministic mode disables fused attn, so rt uses alternate attn backends.
# Combined with FP8 CS, this path needs the looser distributed fp8_cs tolerance policy.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the discrepancy is due to changes in the attention backend, we should only relax the tols with MultiheadAttention and TransformerLayer.

Comment on lines +53 to +54
# SM120: distributed column-parallel path may show a single-element
# activation outlier slightly above default fp32 atol, while grads match.
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a proper bug. If we run on SM 12.0, we want the test to fail rather than giving us a false pass.

…p8::cast_gated_bwd kernel as sm120 shmem requirements are insufficient

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
KshitijLakhani and others added 28 commits May 29, 2026 18:33
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…s Flash and not Fused

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…MM lda constraints

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…debug test activation comparisons

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Route grouped NVFP4 with first_dims through SM120 fallback split quantize path.
- Ensure grouped tensor swizzle metadata reflects actual runtime layout
- Propagate grouped layout metadata to split tensor views instead of re-deriving from quantizer flags.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Select expected scale reference layout from backend-reported _with_gemm_swizzled_scales.
- Assert grouped/split metadata consistency before validating scales.
- Apply SM120-only tolerance relaxation for scale comparisons and skip unsupported SM120 paged-stashing cas

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- SM120 backend currently disables NVFP4 stochastic rounding, so SR no longer outperforms RN.
- Update SR assertions to use close-equality on SM120 and keep strict SR<RN checks for sm!=120.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…shape that was lost in an earlier PR's merge conflict

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…tn backend

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…unfusd cast for sm120 when doing the group quantize for nvfp4

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…ed in fusible-ops tests

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…orm_mlp

  - On SM120 the NVFP4 cast falls back to RN (no SR PTX), the grouped row+col
  RHT fusion is split into unfused passes, and gated-act kernels run the
  non-TMA path.
  - Stacked through LayerNorm + Linear + SwiGLU + Linear, this
  widens the worst-case per-element bwd-pass diff (mainly ffn1.bias.grad /
  ffn2.weight.grad with bias=True) past the loose sanity atol=0.5. Bump
  atol to 0.75 only when (SM120, NVFP4, quantized_compute=True).

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from e33fef9 to b5ddaaa Compare May 29, 2026 18:44
Comment on lines +71 to +72
return transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 120;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 SM121 detection gap leaves several guards incomplete

is_sm120_device() only returns true for sm_arch == 120, but the C++ runtime guard in cublaslt_grouped_gemm.cu explicitly blocks both SM120 and SM121 (sm != 120 && sm != 121), and the C++ test skip in test_grouped_gemm.cu also covers cc == 121. On an SM121 device, this helper returns false, so:

  • The NVFP4 grouped-quantize SM120 fallback in cast.cpp won't trigger → the fused grouped kernel is invoked on a device the guard above already characterised as unsupported.
  • Stochastic-rounding is not disabled in quantizer.cpp, risking unsupported PTX execution.
  • Python-level grouped-GEMM skips in test_fusible_ops.py check (12, 0) and will not fire on SM121, but the C++ NVTE_CHECK will throw → test failure instead of clean skip.

If SM121 has the same shared-memory / PTX constraints as SM120, the helper and every Python capability check should also cover == 121 (or use a range, e.g. sm_arch >= 120 && sm_arch < 130).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants